import gradio as gr
import asyncio
from threading import RLock
from pathlib import Path
import os
from typing import Union

HF_TOKEN = os.getenv("HF_TOKEN", None)
server_timeout = 600  # seconds; passed to the InferenceClient
inference_timeout = 600  # seconds; per-request limit enforced in infer()

lock = RLock()
loaded_models = {}  # model_name -> gr.Interface cache
def rename_image(image_path: Union[str, None], model_name: str, save_path: Union[str, None] = None):
    import shutil
    from datetime import datetime, timezone, timedelta
    if image_path is None: return None
    dt_now = datetime.now(timezone(timedelta(hours=9)))  # JST timestamp for the default filename
    filename = f"{model_name.split('/')[-1]}_{dt_now.strftime('%Y%m%d_%H%M%S')}.png"
    try:
        if Path(image_path).exists():
            png_path = "image.png"
            if str(Path(image_path).resolve()) != str(Path(png_path).resolve()): shutil.copy(image_path, png_path)
            if save_path is not None:
                new_path = str(Path(png_path).resolve().rename(Path(save_path).resolve()))
            else:
                new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
            return new_path
        else:
            return None
    except Exception as e:
        print(e)
        return None
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
from typing import Literal
def load_from_model(model_name: str, hf_token: Union[str, Literal[False], None] = None):
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    print(f"Fetching model from: {model_url}")
    headers = ({} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"})
    response = httpx.request("GET", api_url, headers=headers)
    if response.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )
    p = response.json().get("pipeline_tag")
    if p != "text-to-image": raise ModelNotFoundError(f"This model isn't for text-to-image or is unsupported: {model_name}.")
    headers["X-Wait-For-Model"] = "true"
    kwargs = {}
    if hf_token is not None: kwargs["token"] = hf_token
    client = huggingface_hub.InferenceClient(model=model_name, headers=headers, timeout=server_timeout, **kwargs)
    inputs = gr.components.Textbox(label="Input")
    outputs = gr.components.Image(label="Output")
    fn = client.text_to_image
    def query_huggingface_inference_endpoints(*data, **kwargs):
        try:
            data = fn(*data, **kwargs)  # type: ignore
        except huggingface_hub.utils.HfHubHTTPError as e:
            print(e)
            if "429" in str(e): raise TooManyRequestsError() from e
            # Other HTTP errors fall through and return the unmodified input tuple,
            # which the caller (infer_body) treats as a failure.
        except Exception as e:
            print(e)
            raise Exception() from e
        return data
    interface_info = {
        "fn": query_huggingface_inference_endpoints,
        "inputs": inputs,
        "outputs": outputs,
        "title": model_name,
    }
    return gr.Interface(**interface_info)
def load_model(model_name: str):
    global loaded_models
    global model_info_dict
    if model_name in loaded_models.keys(): return loaded_models[model_name]
    try:
        loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
        print(f"Loaded: {model_name}")
    except Exception as e:
        if model_name in loaded_models.keys(): del loaded_models[model_name]
        print(f"Failed to load: {model_name}")
        print(e)
        return None
    return loaded_models[model_name]
def load_models(models: list):
    for model in models:
        load_model(model)
def warm_model(model_name: str):
    model = load_model(model_name)
    if model:
        try:
            print(f"Warming model: {model_name}")
            infer_body(model, model_name, " ")
        except Exception as e:
            print(e)
def warm_models(models: list[str]):
    from threading import Thread
    # Fire-and-forget: warm each model in its own background thread.
    for model in models:
        Thread(target=warm_model, args=(model,), daemon=True).start()
# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
def infer_body(client: Union[gr.Interface, object], model_str: str, prompt: str, neg_prompt: str = "",
               height: int = 0, width: int = 0, steps: int = 0, cfg: int = 0, seed: int = -1):
    png_path = "image.png"
    kwargs = {}
    if height > 0: kwargs["height"] = height
    if width > 0: kwargs["width"] = width
    if steps > 0: kwargs["num_inference_steps"] = steps
    if cfg > 0: kwargs["guidance_scale"] = cfg
    if seed == -1: seed = randomize_seed()  # keep the drawn seed so it is recorded in the metadata
    kwargs["seed"] = seed
    try:
        if isinstance(client, gr.Interface): image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
        else: return None
        if isinstance(image, tuple): return None  # the query wrapper returns its input tuple on failure
        return save_image(image, png_path, model_str, prompt, neg_prompt, height, width, steps, cfg, seed)
    except Exception as e:
        print(e)
        raise Exception(e) from e
async def infer(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
                steps: int = 0, cfg: int = 0, seed: int = -1,
                save_path: str | None = None, timeout: float = inference_timeout):
    model = load_model(model_name)
    if not model: return None
    task = asyncio.create_task(asyncio.to_thread(infer_body, model, model_name, prompt, neg_prompt,
                                                 height, width, steps, cfg, seed))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done(): task.cancel()
        result = None
        raise Exception(f"Task timed out: {model_name}") from e
    except Exception as e:
        print(e)
        if not task.done(): task.cancel()
        result = None
        raise Exception(e) from e
    if task.done() and result is not None:
        with lock:
            image = rename_image(result, model_name, save_path)
        return image
    return None
def save_image(image, savefile, modelname, prompt, nprompt, height=0, width=0, steps=0, cfg=0, seed=-1):
    from PIL import Image, PngImagePlugin
    import json
    try:
        metadata = {"prompt": prompt, "negative_prompt": nprompt, "Model": {"Model": modelname.split("/")[-1]}}
        if steps > 0: metadata["num_inference_steps"] = steps
        if cfg > 0: metadata["guidance_scale"] = cfg
        if seed != -1: metadata["seed"] = seed
        if width > 0 and height > 0: metadata["resolution"] = f"{width} x {height}"
        metadata_str = json.dumps(metadata)
        info = PngImagePlugin.PngInfo()
        info.add_text("metadata", metadata_str)  # embed the generation parameters as a PNG tEXt chunk
        image.save(savefile, "PNG", pnginfo=info)
        return str(Path(savefile).resolve())
    except Exception as e:
        print(f"Failed to save image file: {e}")
        raise Exception(f"Failed to save image file: {e}") from e
def randomize_seed():
    from random import seed, randint
    MAX_SEED = 2**32 - 1
    seed()
    rseed = randint(0, MAX_SEED)
    return rseed
def gen_image(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
              steps: int = 0, cfg: int = 0, seed: int = -1):
    if model_name in ["NA", ""]: return gr.update()
    # Run the async inference on a dedicated event loop; run_until_complete()
    # cannot be used on a loop that is already running.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
                                               steps, cfg, seed, None, inference_timeout))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}, Error: {e}")
        result = None
        raise gr.Error(f"Task aborted: {model_name}, Error: {e}") from e
    finally:
        loop.close()
    return result
def generate_image_hf(model_name: str, prompt: str, negative_prompt: str, use_defaults: bool, resolution: str,
                      guidance_scale: float, num_inference_steps: int, seed: int, randomize_seed: bool, progress=gr.Progress()):
    if randomize_seed: seed = -1  # -1 tells infer_body() to draw a random seed
    if use_defaults:
        prompt = f"{prompt}, best quality, amazing quality, very aesthetic"
        negative_prompt = f"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], {negative_prompt}"
    width, height = map(int, resolution.split('x'))
    image = gen_image(model_name, prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, seed)
    metadata_text = f"{prompt}\nNegative prompt: {negative_prompt}\nSteps: {num_inference_steps}, Sampler: Euler a, Size: {width}x{height}, Seed: {seed}, CFG scale: {guidance_scale}"
    return image, seed, metadata_text
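# Minimal usage sketch (an assumption, not part of the original code): how generate_image_hf()
# could be wired into a gr.Blocks UI. Component choices and defaults here are illustrative only.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model", value="stabilityai/stable-diffusion-xl-base-1.0")
        prompt = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative prompt")
        use_defaults = gr.Checkbox(label="Use default quality tags", value=True)
        resolution = gr.Dropdown(choices=["1024x1024", "832x1216", "1216x832"], value="1024x1024", label="Resolution")
        guidance_scale = gr.Slider(1.0, 20.0, value=7.0, step=0.5, label="Guidance scale")
        num_inference_steps = gr.Slider(1, 50, value=28, step=1, label="Steps")
        seed = gr.Number(value=-1, precision=0, label="Seed")
        randomize = gr.Checkbox(label="Randomize seed", value=True)
        image = gr.Image(label="Result")
        used_seed = gr.Number(label="Used seed")
        metadata = gr.Textbox(label="Metadata")
        gr.Button("Generate").click(
            generate_image_hf,
            inputs=[model_name, prompt, negative_prompt, use_defaults, resolution,
                    guidance_scale, num_inference_steps, seed, randomize],
            outputs=[image, used_seed, metadata],
        )
    demo.launch()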